import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from model.criterion.quantile_loss import QuantileLoss
from torchcp.regression.utils import build_regression_model
from model.trainer.score.BaseScore import BaseScore

class CQR(BaseScore):
    """
    Conformalized Quantile Regression (CQR)
    This score function allows for calculating scores and generating prediction intervals
    using quantile regression model.

    Reference:
        Paper: Conformalized Quantile Regression (Romano et al., 2019)
        Link: https://proceedings.neurips.cc/paper_files/paper/2019/file/5103c3584b063c431bd1268e9b5e76fb-Paper.pdf
        Github: https://github.com/yromano/cqr
    """

    def __init__(self):
        super().__init__()

    def __call__(self, predicts, y_truth):
        """
        Compute the CQR conformity scores, clamped to be non-negative.

        Args:
            predicts (torch.Tensor): Quantile predictions. The last dimension holds
                ``quantile_num * output_dim`` values and is interpreted as
                ``quantile_num`` consecutive groups of ``output_dim`` values. Only the
                first group (lower quantile) and the last group (upper quantile) are used.
            y_truth (torch.Tensor): Ground-truth targets. When it has the same number of
                dimensions as ``predicts`` its last dimension is taken as ``output_dim``;
                otherwise a trailing dimension of size 1 is appended and ``output_dim`` is 1.

        Returns:
            torch.Tensor: ``max(lower - y_truth, y_truth - upper)`` with negative values
            replaced by 0; the last dimension has size ``output_dim``.

        Raises:
            AssertionError: If the last dimension of ``predicts`` is not a multiple of
                ``output_dim``.
        """
        if len(predicts.shape) != len(y_truth.shape):
            # Targets without an explicit output axis: treat as a single output.
            output_dim = 1
            quantile_num = predicts.shape[-1]
            y_truth = y_truth.unsqueeze(-1)
        else:
            output_dim = y_truth.shape[-1]
            quantile_num = predicts.shape[-1] // output_dim

        assert predicts.shape[-1] == quantile_num * output_dim, (
            f"Last dimension size ({predicts.shape[-1]}) does not match "
            f"{quantile_num} * {output_dim} = {quantile_num * output_dim}"
        )

        # Split the last axis into (quantile_num, output_dim). reshape (rather than view)
        # also accepts non-contiguous inputs such as sliced or transposed tensors.
        new_shape = predicts.shape[:-1] + (quantile_num, output_dim)
        predicts = predicts.reshape(new_shape)

        # First quantile group is the lower bound, last group is the upper bound.
        lower_quantiles = predicts[..., 0, :]
        upper_quantiles = predicts[..., -1, :]

        # Positive when y_truth falls outside [lower, upper], negative when inside.
        scores = torch.maximum(lower_quantiles - y_truth, y_truth - upper_quantiles)

        # NOTE(review): the score in Romano et al. is signed; here points inside the
        # interval are given score 0. Kept as in the original code - confirm intended.
        return torch.clamp(scores, min=0)

    def generate_intervals(self, predicts_batch, q_hat):
        """
        Widen the predicted quantile band by the calibrated margin ``q_hat``.

        Args:
            predicts_batch (torch.Tensor or np.ndarray): Quantile predictions. The first
                ``output_dim`` entries of the last dimension are used as lower bounds and
                the last ``output_dim`` entries as upper bounds. A NumPy array is converted
                to a tensor on ``q_hat``'s device and given a leading dimension of size 1.
            q_hat (torch.Tensor): Calibrated margin. A 1-D tensor is treated as one margin
                per output dimension; otherwise its last dimension is ``output_dim``.

        Returns:
            torch.Tensor: Tensor whose last dimension has size 2, holding
            ``lower - q_hat`` at index 0 and ``upper + q_hat`` at index 1.
        """
        if isinstance(predicts_batch, np.ndarray):
            predicts_batch = torch.from_numpy(predicts_batch).to(q_hat.device)
            predicts_batch = predicts_batch.unsqueeze(dim=0)

        if len(q_hat.shape) != 1:
            output_dim = q_hat.shape[-1]
        else:
            output_dim = q_hat.shape[0]
            q_hat = q_hat.unsqueeze(0)

        lower = predicts_batch[..., 0:output_dim]
        upper = predicts_batch[..., -output_dim:]

        if len(predicts_batch.shape) == 3:
            prediction_intervals = predicts_batch.new_zeros((
                predicts_batch.shape[0], q_hat.shape[0], q_hat.shape[-1], 2))
            # Add a leading broadcast axis so the margin lines up with the batch axis.
            margin = q_hat.reshape(1, q_hat.shape[0], q_hat.shape[1])
        elif len(predicts_batch.shape) == 2:
            prediction_intervals = predicts_batch.new_zeros((predicts_batch.shape[0], q_hat.shape[-1], 2))
            # Use the same (1, output_dim) margin for both bounds.
            margin = q_hat.reshape(1, output_dim)
        else:
            # Fallback for any other rank, kept exactly as in the original implementation.
            prediction_intervals = predicts_batch.new_zeros((predicts_batch.shape[0], 2))
            margin = q_hat

        prediction_intervals[..., 0] = lower - margin
        prediction_intervals[..., 1] = upper + margin

        return prediction_intervals

    def train(self, train_dataloader, **kwargs):
        """
        Train a quantile regression model and return it.

        Args:
            train_dataloader: Iterable of training batches. Only when no model is supplied,
                the first element of its first batch is inspected and its ``shape[1]`` is
                used as the input size of the default network.
            **kwargs: Optional settings:
                device: Device the default model is moved to (default ``None``). A model
                    passed via ``model`` is not moved.
                model: Model to train. Defaults to a ``"NonLinearNet"`` built through
                    ``build_regression_model`` with arguments ``(input_size, 2, 64, 0.5)``.
                criterion: Loss function. Defaults to ``QuantileLoss`` with quantiles
                    ``[alpha / 2, 1 - alpha / 2]``.
                alpha: Required when ``criterion`` is not provided.
                epochs: Number of epochs (default 100).
                lr: Learning rate of the default optimizer (default 0.01).
                optimizer: Optimizer to use. Defaults to Adam over the model parameters.
                verbose: Passed through to the training loop (default ``True``).

        Returns:
            The trained model.

        Raises:
            ValueError: If neither ``criterion`` nor ``alpha`` is provided.
        """
        device = kwargs.get('device', None)

        # Build the default model lazily: a dict.get() default would be evaluated even
        # when the caller supplies a model, needlessly pulling a batch from the
        # dataloader and constructing a network that is thrown away.
        model = kwargs.get('model', None)
        if model is None:
            input_dim = next(iter(train_dataloader))[0].shape[1]
            model = build_regression_model("NonLinearNet")(input_dim, 2, 64, 0.5).to(device)

        criterion = kwargs.get('criterion', None)
        if criterion is None:
            alpha = kwargs.get('alpha', None)
            if alpha is None:
                raise ValueError("When 'criterion' is not provided, 'alpha' must be specified.")
            quantiles = [alpha / 2, 1 - alpha / 2]
            criterion = QuantileLoss(quantiles)

        epochs = kwargs.get('epochs', 100)
        lr = kwargs.get('lr', 0.01)

        # Same lazy construction for the optimizer.
        optimizer = kwargs.get('optimizer', None)
        if optimizer is None:
            optimizer = optim.Adam(model.parameters(), lr=lr)

        verbose = kwargs.get('verbose', True)

        self._basetrain(model, epochs, train_dataloader, criterion, optimizer, verbose)
        return model
